{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2325502b",
   "metadata": {},
   "source": [
    "# Customize AutoMM\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/customization.ipynb)\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/customization.ipynb)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "AutoMM has a powerful yet easy-to-use configuration design.\n",
    "This tutorial walks you through various AutoMM configurations to empower you the customization flexibility. Specifically, AutoMM configurations consist of several parts:\n",
    "\n",
    "- optimization\n",
    "- environment\n",
    "- model\n",
    "- data\n",
    "- distiller\n",
    "\n",
    "## Optimization\n",
    "\n",
    "### optimization.learning_rate\n",
    "\n",
    "Learning rate."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f349cfb",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.learning_rate\": 1.0e-4})\n",
    "# set learning rate to 5.0e-4\n",
    "predictor.fit(hyperparameters={\"optimization.learning_rate\": 5.0e-4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea63ec5",
   "metadata": {},
   "source": [
    "### optimization.optim_type\n",
    "\n",
    "Optimizer type.\n",
    "\n",
    "- `\"sgd\"`: stochastic gradient descent with momentum.\n",
    "- `\"adam\"`: a stochastic gradient descent method that is based on adaptive estimation of first-order and second-order moments. See [this paper](https://arxiv.org/abs/1412.6980) for details.\n",
    "- `\"adamw\"`: improves adam by decoupling the weight decay from the optimization step. See [this paper](https://arxiv.org/abs/1711.05101) for details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37fcd9e2",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.optim_type\": \"adamw\"})\n",
    "# use optimizer adam\n",
    "predictor.fit(hyperparameters={\"optimization.optim_type\": \"adam\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aeb40726",
   "metadata": {},
   "source": [
    "### optimization.weight_decay\n",
    "\n",
    "Weight decay."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "852daaee",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.weight_decay\": 1.0e-3})\n",
    "# set weight decay to 1.0e-4\n",
    "predictor.fit(hyperparameters={\"optimization.weight_decay\": 1.0e-4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7006101a",
   "metadata": {},
   "source": [
    "### optimization.lr_decay\n",
    "\n",
    "Later layers can have larger learning rates than the earlier layers. The last/head layer\n",
    "has the largest learning rate `optimization.learning_rate`. For a model with `n` layers, layer `i` has learning rate `optimization.learning_rate * optimization.lr_decay^(n-i)`. To use one uniform learning rate, simply set the learning rate decay to `1`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adb57179",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.lr_decay\": 0.9})\n",
    "# turn off learning rate decay\n",
    "predictor.fit(hyperparameters={\"optimization.lr_decay\": 1})\n",
    "```\n"
   ]
  },
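  {
   "cell_type": "markdown",
   "id": "9d2f4a10",
   "metadata": {},
   "source": [
    "To make the layer-wise formula concrete, here is a minimal sketch that computes the per-layer learning rates for a hypothetical 4-layer model. The helper `layerwise_learning_rates` is purely illustrative and is not part of the AutoMM API.\n",
    "\n",
    "```python\n",
    "def layerwise_learning_rates(learning_rate, lr_decay, n_layers):\n",
    "    # layer i (1-based) gets learning_rate * lr_decay^(n - i); layer n is the head\n",
    "    return [learning_rate * lr_decay ** (n_layers - i) for i in range(1, n_layers + 1)]\n",
    "\n",
    "# hypothetical 4-layer model with the default learning rate and decay\n",
    "rates = layerwise_learning_rates(learning_rate=1.0e-4, lr_decay=0.9, n_layers=4)\n",
    "for layer, rate in enumerate(rates, start=1):\n",
    "    print(f'layer {layer}: {rate:.3e}')\n",
    "```\n",
    "\n",
    "The head keeps the full `optimization.learning_rate`, and each earlier layer is scaled down by one more factor of `optimization.lr_decay`."
   ]
  },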
  {
   "cell_type": "markdown",
   "id": "59632914",
   "metadata": {},
   "source": [
    "### optimization.lr_mult\n",
    "\n",
    "While we are using two_stages lr choice,\n",
    "The last/head layer has the largest learning rate `optimization.learning_rate` * `optimization.lr_mult`.\n",
    "And other layers has normal learning rate `optimization.learning_rate`.\n",
    "To use one uniform learning rate, simply set the learning rate multiple to `1`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1770634e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.lr_mult\": 1})\n",
    "# turn on two-stage lr for 10 times learning rate in head layer\n",
    "predictor.fit(hyperparameters={\"optimization.lr_mult\": 10})\n",
    "```\n"
   ]
  },
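  {
   "cell_type": "markdown",
   "id": "9d2f4a11",
   "metadata": {},
   "source": [
    "As a small illustration of the two-stage rule, the sketch below assigns learning rates to a hypothetical 4-layer model. The helper `two_stage_learning_rates` is an assumption made for this example, not an AutoMM function.\n",
    "\n",
    "```python\n",
    "def two_stage_learning_rates(learning_rate, lr_mult, n_layers):\n",
    "    # every layer uses the base rate, except the head (last layer)\n",
    "    rates = [learning_rate] * n_layers\n",
    "    rates[-1] = learning_rate * lr_mult\n",
    "    return rates\n",
    "\n",
    "rates = two_stage_learning_rates(learning_rate=1.0e-4, lr_mult=10, n_layers=4)\n",
    "print(rates)\n",
    "```"
   ]
  },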
  {
   "cell_type": "markdown",
   "id": "60b67198",
   "metadata": {},
   "source": [
    "### optimization.lr_choice\n",
    "\n",
    "We may want different layers to have different lr,\n",
    "here we have strategy `two_stages` lr choice (see `optimization.lr_mult` section for more details),\n",
    "or `layerwise_decay` lr choice (see `optimization.lr_decay` section for more details).\n",
    "To use one uniform learning rate, simply set this to `\"\"`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3dd4814",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.lr_choice\": \"layerwise_decay\"})\n",
    "# turn on two-stage lr choice\n",
    "predictor.fit(hyperparameters={\"optimization.lr_choice\": \"two_stages\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31355e32",
   "metadata": {},
   "source": [
    "### optimization.lr_schedule\n",
    "\n",
    "Learning rate schedule.\n",
    "\n",
    "- `\"cosine_decay\"`: the decay of learning rate follows the cosine curve.\n",
    "- `\"polynomial_decay\"`: the learning rate is decayed based on polynomial functions.\n",
    "- `\"linear_decay\"`: linearly decays the learing rate."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "100a6a22",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.lr_schedule\": \"cosine_decay\"})\n",
    "# use polynomial decay\n",
    "predictor.fit(hyperparameters={\"optimization.lr_schedule\": \"polynomial_decay\"})\n",
    "```\n"
   ]
  },
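  {
   "cell_type": "markdown",
   "id": "9d2f4a12",
   "metadata": {},
   "source": [
    "The three schedules can be compared with a rough sketch of the factor that multiplies the base learning rate. The formulas below are the common textbook forms, and the polynomial `power` of 2.0 is an assumption. AutoMM's actual implementation may differ in details such as the final learning rate.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def decay_factor(schedule, progress, power=2.0):\n",
    "    # progress is the fraction of training completed, in [0, 1]\n",
    "    if schedule == 'cosine_decay':\n",
    "        return 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "    if schedule == 'linear_decay':\n",
    "        return 1.0 - progress\n",
    "    if schedule == 'polynomial_decay':\n",
    "        return (1.0 - progress) ** power\n",
    "    raise ValueError(f'unknown schedule: {schedule}')\n",
    "\n",
    "for name in ['cosine_decay', 'polynomial_decay', 'linear_decay']:\n",
    "    print(name, [round(decay_factor(name, p), 3) for p in (0.0, 0.5, 1.0)])\n",
    "```"
   ]
  },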
  {
   "cell_type": "markdown",
   "id": "5ce9a4af",
   "metadata": {},
   "source": [
    "### optimization.max_epochs\n",
    "\n",
    "Stop training once this number of epochs is reached."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a8c032",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.max_epochs\": 10})\n",
    "# train 20 epochs\n",
    "predictor.fit(hyperparameters={\"optimization.max_epochs\": 20})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53264b7c",
   "metadata": {},
   "source": [
    "### optimization.max_steps\n",
    "\n",
    "Stop training after this number of steps. Training will stop if `optimization.max_steps` or `optimization.max_epochs` have reached (earliest).\n",
    "By default, we disable `optimization.max_steps` by setting it to -1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2482fc3e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.max_steps\": -1})\n",
    "# train 100 steps\n",
    "predictor.fit(hyperparameters={\"optimization.max_steps\": 100})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd6fa991",
   "metadata": {},
   "source": [
    "### optimization.warmup_steps\n",
    "\n",
    "Warm up the learning rate from 0 to `optimization.learning_rate` within this percentage of steps at the beginning of training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34d3a967",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.warmup_steps\": 0.1})\n",
    "# do learning rate warmup in the first 20% steps.\n",
    "predictor.fit(hyperparameters={\"optimization.warmup_steps\": 0.2})\n",
    "```\n"
   ]
  },
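  {
   "cell_type": "markdown",
   "id": "9d2f4a13",
   "metadata": {},
   "source": [
    "The following sketch shows how a warmup fraction could translate into per-step learning rates. It assumes a linear ramp and a hypothetical run of 100 total steps. The function name and the linear shape are assumptions for illustration only.\n",
    "\n",
    "```python\n",
    "def warmup_learning_rate(step, total_steps, learning_rate, warmup_steps):\n",
    "    # warmup_steps is a fraction of total_steps, as in optimization.warmup_steps\n",
    "    n_warmup = int(total_steps * warmup_steps)\n",
    "    if n_warmup > 0 and step < n_warmup:\n",
    "        return learning_rate * step / n_warmup\n",
    "    return learning_rate  # after warmup, the lr schedule takes over\n",
    "\n",
    "lrs = [warmup_learning_rate(s, total_steps=100, learning_rate=1.0e-4, warmup_steps=0.1) for s in range(12)]\n",
    "print(lrs)\n",
    "```"
   ]
  },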
  {
   "cell_type": "markdown",
   "id": "80452db2",
   "metadata": {},
   "source": [
    "### optimization.patience\n",
    "\n",
    "Stop training after this number of checks with no improvement. The check frequency is controlled by `optimization.val_check_interval`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3bcd482",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.patience\": 10})\n",
    "# set patience to 5 checks\n",
    "predictor.fit(hyperparameters={\"optimization.patience\": 5})\n",
    "```\n"
   ]
  },
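  {
   "cell_type": "markdown",
   "id": "9d2f4a14",
   "metadata": {},
   "source": [
    "To illustrate how patience interacts with validation checks, here is a minimal early-stopping sketch. It assumes a higher validation score is better and that a tie does not count as an improvement. The function `stopping_check` and the scores are made up for this example.\n",
    "\n",
    "```python\n",
    "def stopping_check(val_scores, patience):\n",
    "    # return the index of the check at which training stops, or None\n",
    "    best, bad_checks = float('-inf'), 0\n",
    "    for i, score in enumerate(val_scores):\n",
    "        if score > best:\n",
    "            best, bad_checks = score, 0\n",
    "        else:\n",
    "            bad_checks += 1\n",
    "            if bad_checks >= patience:\n",
    "                return i\n",
    "    return None\n",
    "\n",
    "scores = [0.70, 0.75, 0.74, 0.75, 0.73, 0.76, 0.75]\n",
    "stop_at = stopping_check(scores, patience=3)\n",
    "print(stop_at)\n",
    "```\n",
    "\n",
    "With a larger patience, the same run would continue past the temporary plateau and reach the later, better score."
   ]
  },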
  {
   "cell_type": "markdown",
   "id": "2765c653",
   "metadata": {},
   "source": [
    "### optimization.val_check_interval\n",
    "\n",
    "How often within one training epoch to check the validation set. Can specify as float or int.\n",
    "\n",
    "- pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch.\n",
    "- pass an int to check after a fixed number of training batches."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ee8c226",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.val_check_interval\": 0.5})\n",
    "# check validation set 4 times during a training epoch\n",
    "predictor.fit(hyperparameters={\"optimization.val_check_interval\": 0.25})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e28553d8",
   "metadata": {},
   "source": [
    "### optimization.gradient_clip_algorithm\n",
    "\n",
    "The gradient clipping algorithm to use. Support to clip gradients by value or norm."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7526131f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.gradient_clip_algorithm\": \"norm\"})\n",
    "# clip gradients by value\n",
    "predictor.fit(hyperparameters={\"optimization.gradient_clip_algorithm\": \"value\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c29d0454",
   "metadata": {},
   "source": [
    "### optimization.gradient_clip_val\n",
    "\n",
    "Gradient clipping value, which can be the absolute value or gradient norm depending on the choice of `optimization.gradient_clip_algorithm`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50e90350",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.gradient_clip_val\": 1})\n",
    "# cap the gradients to 5\n",
    "predictor.fit(hyperparameters={\"optimization.gradient_clip_val\": 5})\n",
    "```\n"
   ]
  },
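  {
   "cell_type": "markdown",
   "id": "9d2f4a15",
   "metadata": {},
   "source": [
    "The difference between the two clipping algorithms can be sketched in plain Python on a toy gradient vector. This is a simplified illustration, not the code used during training: real implementations operate on tensors and may add a small epsilon to the norm.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def clip_by_value(grads, clip_val):\n",
    "    # clamp every component into [-clip_val, clip_val]\n",
    "    return [max(-clip_val, min(clip_val, g)) for g in grads]\n",
    "\n",
    "def clip_by_norm(grads, clip_val):\n",
    "    # rescale the whole vector so that its 2-norm is at most clip_val\n",
    "    norm = math.sqrt(sum(g * g for g in grads))\n",
    "    if norm <= clip_val:\n",
    "        return list(grads)\n",
    "    return [g * clip_val / norm for g in grads]\n",
    "\n",
    "grads = [3.0, -4.0]  # the 2-norm is 5.0\n",
    "print(clip_by_value(grads, 1.0))\n",
    "print(clip_by_norm(grads, 1.0))\n",
    "```\n",
    "\n",
    "Clipping by norm preserves the direction of the gradient, while clipping by value can change it."
   ]
  },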
  {
   "cell_type": "markdown",
   "id": "02d07866",
   "metadata": {},
   "source": [
    "### optimization.track_grad_norm\n",
    "\n",
    "Track the p-norm of gradients during training. May be set to ‘inf’ infinity-norm. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b60c371",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM (no tracking)\n",
    "predictor.fit(hyperparameters={\"optimization.track_grad_norm\": -1})\n",
    "# track the 2-norm\n",
    "predictor.fit(hyperparameters={\"optimization.track_grad_norm\": 2})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abe87d32",
   "metadata": {},
   "source": [
    "### optimization.log_every_n_steps\n",
    "\n",
    "How often to log within steps."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f5fe49c",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.log_every_n_steps\": 10})\n",
    "# log once every 50 steps\n",
    "predictor.fit(hyperparameters={\"optimization.log_every_n_steps\": 50})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28f30c7e",
   "metadata": {},
   "source": [
    "### optimization.top_k\n",
    "\n",
    "Based on the validation score, choose top k model checkpoints to do model averaging."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17258cf4",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.top_k\": 3})\n",
    "# use top 5 checkpoints\n",
    "predictor.fit(hyperparameters={\"optimization.top_k\": 5})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a9f1cfc",
   "metadata": {},
   "source": [
    "### optimization.top_k_average_method\n",
    "\n",
    "Use what strategy to average the top k model checkpoints.\n",
    "\n",
    "- `\"greedy_soup\"`: tries to add the checkpoints from best to worst into the averaging pool and stop if the averaged checkpoint performance decreases. See [the paper](https://arxiv.org/pdf/2203.05482.pdf) for details.\n",
    "- `\"uniform_soup\"`: averages all the top k checkpoints as the final checkpoint.\n",
    "- `\"best\"`: picks the checkpoint with the best validation performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8236ab40",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.top_k_average_method\": \"greedy_soup\"})\n",
    "# average all the top k checkpoints\n",
    "predictor.fit(hyperparameters={\"optimization.top_k_average_method\": \"uniform_soup\"})\n",
    "```\n"
   ]
  },
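  {
   "cell_type": "markdown",
   "id": "9d2f4a16",
   "metadata": {},
   "source": [
    "The averaging strategies can be sketched on toy checkpoints. In this illustration each checkpoint is a dict holding a single float weight, and `evaluate` is a hypothetical validation scorer where higher is better. The greedy variant follows the description above and stops at the first decrease. None of these helpers are AutoMM functions.\n",
    "\n",
    "```python\n",
    "def average_checkpoints(checkpoints):\n",
    "    # element-wise mean of parameter dicts that share the same keys\n",
    "    return {k: sum(c[k] for c in checkpoints) / len(checkpoints) for k in checkpoints[0]}\n",
    "\n",
    "def uniform_soup(checkpoints):\n",
    "    return average_checkpoints(checkpoints)\n",
    "\n",
    "def greedy_soup(checkpoints, evaluate):\n",
    "    # checkpoints must be sorted from best to worst validation score\n",
    "    pool = [checkpoints[0]]\n",
    "    best_score = evaluate(average_checkpoints(pool))\n",
    "    for candidate in checkpoints[1:]:\n",
    "        score = evaluate(average_checkpoints(pool + [candidate]))\n",
    "        if score < best_score:\n",
    "            break  # stop once averaging hurts validation performance\n",
    "        pool.append(candidate)\n",
    "        best_score = score\n",
    "    return average_checkpoints(pool)\n",
    "\n",
    "# toy setup: the ideal weight is 1.0\n",
    "top_k = [{'w': 0.9}, {'w': 1.1}, {'w': 2.0}]\n",
    "evaluate = lambda params: -abs(params['w'] - 1.0)\n",
    "print(greedy_soup(top_k, evaluate))\n",
    "print(uniform_soup(top_k))\n",
    "```\n",
    "\n",
    "In this toy case the greedy soup keeps only the first two checkpoints, because adding the third would lower the score."
   ]
  },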
  {
   "cell_type": "markdown",
   "id": "652f15b4",
   "metadata": {},
   "source": [
    "### optimization.efficient_finetune\n",
    "\n",
    "Options for parameter-efficient finetuning. Parameter-efficient finetuning means to finetune only a small portion of parameters instead of the whole pretrained backbone.\n",
    "\n",
    "- `\"bit_fit\"`: bias parameters only. See [this paper](https://arxiv.org/pdf/2106.10199.pdf) for details.\n",
    "- `\"norm_fit\"`: normalization parameters + bias parameters. See [this paper](https://arxiv.org/pdf/2003.00152.pdf) for details.\n",
    "- `\"lora\"`: LoRA Adaptors. See [this paper](https://arxiv.org/pdf/2106.09685.pdf) for details.\n",
    "- `\"lora_bias\"`: LoRA Adaptors + bias parameters.\n",
    "- `\"lora_norm\"`: LoRA Adaptors + normalization parameters + bias parameters.\n",
    "- `\"ia3\"`: IA3 algorithm. See [this paper](https://arxiv.org/abs/2205.05638) for details.\n",
    "- `\"ia3_bias\"`: IA3 + bias parameters.\n",
    "- `\"ia3_norm\"`: IA3 + normalization parameters + bias parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a072add",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.efficient_finetune\": None})\n",
    "# finetune only bias parameters\n",
    "predictor.fit(hyperparameters={\"optimization.efficient_finetune\": \"bit_fit\"})\n",
    "# finetune with IA3 + BitFit\n",
    "predictor.fit(hyperparameters={\"optimization.efficient_finetune\": \"ia3_bias\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a138ce84-1462-4d67-a82e-80e829a7a57d",
   "metadata": {},
   "source": [
    "### optimization.skip_final_val\n",
    "\n",
    "Whether to skip the final validation after training is signaled to stop."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6830fa7f-d6ef-4578-9efd-16923fca0918",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"optimization.skip_final_val\": False})\n",
    "# skip the final validation\n",
    "predictor.fit(hyperparameters={\"optimization.skip_final_val\": True})\n",
    "```\n"
   ]
  },
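  {
   "cell_type": "markdown",
   "id": "9d2f4a17",
   "metadata": {},
   "source": [
    "Since `hyperparameters` is an ordinary dictionary, several optimization options can be set in one call. The sketch below only builds the dictionary, using values taken from the examples above. The `fit` call is left commented out because it assumes an existing `predictor` and training data.\n",
    "\n",
    "```python\n",
    "hyperparameters = {\n",
    "    \"optimization.learning_rate\": 5.0e-4,\n",
    "    \"optimization.lr_choice\": \"two_stages\",\n",
    "    \"optimization.lr_mult\": 10,\n",
    "    \"optimization.max_epochs\": 20,\n",
    "    \"optimization.patience\": 5,\n",
    "}\n",
    "# predictor.fit(hyperparameters=hyperparameters)\n",
    "print(sorted(hyperparameters))\n",
    "```"
   ]
  },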
  {
   "cell_type": "markdown",
   "id": "e3cfbd99",
   "metadata": {},
   "source": [
    "## Environment\n",
    "\n",
    "### env.num_gpus\n",
    "\n",
    "The number of gpus to use. If given -1, we count the GPUs by `env.num_gpus = torch.cuda.device_count()`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6908008",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, all available gpus are used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.num_gpus\": -1})\n",
    "# use 1 gpu only\n",
    "predictor.fit(hyperparameters={\"env.num_gpus\": 1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e3c075",
   "metadata": {},
   "source": [
    "### env.per_gpu_batch_size\n",
    "\n",
    "The batch size for each GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "692d5f4f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.per_gpu_batch_size\": 8})\n",
    "# use batch size 16 per GPU\n",
    "predictor.fit(hyperparameters={\"env.per_gpu_batch_size\": 16})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a23b8cc",
   "metadata": {},
   "source": [
    "### env.batch_size\n",
    "\n",
    "The batch size to use in each step of training. If `env.batch_size` is larger than `env.per_gpu_batch_size * env.num_gpus`, we accumulate gradients to reach the effective `env.batch_size` before performing one optimization step. The accumulation steps are calculated by `env.batch_size // (env.per_gpu_batch_size * env.num_gpus)`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14b5a0c2",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.batch_size\": 128})\n",
    "# use batch size 256\n",
    "predictor.fit(hyperparameters={\"env.batch_size\": 256})\n",
    "```\n"
   ]
  },
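  {
   "cell_type": "markdown",
   "id": "9d2f4a18",
   "metadata": {},
   "source": [
    "As a worked example of the accumulation formula, the sketch below computes the number of gradient accumulation steps for two hypothetical machines. The helper name is an assumption made for illustration.\n",
    "\n",
    "```python\n",
    "def accumulation_steps(batch_size, per_gpu_batch_size, num_gpus):\n",
    "    # env.batch_size // (env.per_gpu_batch_size * env.num_gpus)\n",
    "    return batch_size // (per_gpu_batch_size * num_gpus)\n",
    "\n",
    "# defaults on a single GPU: 16 forward/backward passes per optimization step\n",
    "print(accumulation_steps(batch_size=128, per_gpu_batch_size=8, num_gpus=1))\n",
    "# the same settings on 4 GPUs: 4 passes per optimization step\n",
    "print(accumulation_steps(batch_size=128, per_gpu_batch_size=8, num_gpus=4))\n",
    "```"
   ]
  },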
  {
   "cell_type": "markdown",
   "id": "fdf820c7",
   "metadata": {},
   "source": [
    "### env.eval_batch_size_ratio\n",
    "\n",
    "Prediction or evaluation uses a larger per gpu batch size `env.per_gpu_batch_size * env.eval_batch_size_ratio`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3d8a8ca",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.eval_batch_size_ratio\": 4})\n",
    "# use 2x per gpu batch size during prediction or evaluation\n",
    "predictor.fit(hyperparameters={\"env.eval_batch_size_ratio\": 2})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d492098",
   "metadata": {},
   "source": [
    "### env.precision\n",
    "\n",
    "Support either double (`64`), float (`32`), bfloat16 (`\"bf16\"`), or half (`16`) precision training.\n",
    "\n",
    "Half precision, or mixed precision, is the combined use of 32 and 16 bit floating points to reduce memory footprint during model training. This can result in improved performance, achieving +3x speedups on modern GPUs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c348024",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.precision\": 16})\n",
    "# use bfloat16\n",
    "predictor.fit(hyperparameters={\"env.precision\": \"bf16\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aa8d934",
   "metadata": {},
   "source": [
    "### env.num_workers\n",
    "\n",
    "The number of worker processes used by the Pytorch dataloader in training. Note that more workers don't always bring speedup especially when `env.strategy = \"ddp_spawn\"`.\n",
    "For more details, see the guideline [here](https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html#distributed-data-parallel)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "789bed40",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.num_workers\": 2})\n",
    "# use 4 workers in the training dataloader\n",
    "predictor.fit(hyperparameters={\"env.num_workers\": 4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86faccf9",
   "metadata": {},
   "source": [
    "### env.num_workers_evaluation\n",
    "\n",
    "The number of worker processes used by the Pytorch dataloader in prediction or evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4040737",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.num_workers_evaluation\": 2})\n",
    "# use 4 workers in the prediction/evaluation dataloader\n",
    "predictor.fit(hyperparameters={\"env.num_workers_evaluation\": 4})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4ea42b0",
   "metadata": {},
   "source": [
    "### env.strategy\n",
    "\n",
    "Distributed training mode.\n",
    "\n",
    "- `\"dp\"`: data parallel.\n",
    "- `\"ddp\"`: distributed data parallel (python script based).\n",
    "- `\"ddp_spawn\"`: distributed data parallel (spawn based).\n",
    "\n",
    "See [here](https://lightning.ai/docs/pytorch/stable/extensions/strategy.html#selecting-a-built-in-strategy) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab1c3e6f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"env.strategy\": \"ddp_spawn\"})\n",
    "# use ddp during training\n",
    "predictor.fit(hyperparameters={\"env.strategy\": \"ddp\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2b4a051",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "### model.names\n",
    "\n",
    "Choose what types of models to use.\n",
    "\n",
    "- `\"hf_text\"`: the pretrained text models from [Huggingface](https://huggingface.co/).\n",
    "- `\"timm_image\"`: the pretrained image models from [TIMM](https://github.com/rwightman/pytorch-image-models/tree/master/timm/models).\n",
    "- `\"clip\"`: the pretrained CLIP models.\n",
    "- `\"categorical_mlp\"`: MLP for categorical data.\n",
    "- `\"numerical_mlp\"`: MLP for numerical data.\n",
    "- `\"categorical_transformer\"`: [FT-Transformer](https://arxiv.org/pdf/2106.11959.pdf) for categorical data.\n",
    "- `\"numerical_transformer\"`: [FT-Transformer](https://arxiv.org/pdf/2106.11959.pdf) for numerical data.\n",
    "- `\"fusion_mlp\"`: MLP-based fusion for features from multiple backbones.\n",
    "- `\"fusion_transformer\"`: transformer-based fusion for features from multiple backbones.\n",
    "\n",
    "If no data of one modality is detected, the related model types will be automatically removed in training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bc0c2e5",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"hf_text\", \"timm_image\", \"clip\", \"categorical_mlp\", \"numerical_mlp\", \"fusion_mlp\"]})\n",
    "# use only text models\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"hf_text\"]})\n",
    "# use only image models\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"timm_image\"]})\n",
    "# use only clip models\n",
    "predictor.fit(hyperparameters={\"model.names\": [\"clip\"]})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2d1c833",
   "metadata": {},
   "source": [
    "### model.hf_text.checkpoint_name\n",
    "\n",
    "Specify a text backbone supported by the Hugginface [AutoModel](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#automodel)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27360756",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.checkpoint_name\": \"google/electra-base-discriminator\"})\n",
    "# choose roberta base\n",
    "predictor.fit(hyperparameters={\"model.hf_text.checkpoint_name\": \"roberta-base\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bff5c1bd",
   "metadata": {},
   "source": [
    "### model.hf_text.pooling_mode\n",
    "\n",
    "The feature pooling mode for transformer architectures.\n",
    "\n",
    "- `cls`: uses the cls feature vector to represent a sentence.\n",
    "- `mean`: averages all the token feature vectors to represent a sentence."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1359b199",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.pooling_mode\": \"cls\"})\n",
    "# using the mean pooling\n",
    "predictor.fit(hyperparameters={\"model.hf_text.pooling_mode\": \"mean\"})\n",
    "```\n"
   ]
  },
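  {
   "cell_type": "markdown",
   "id": "9d2f4a19",
   "metadata": {},
   "source": [
    "The two pooling modes can be illustrated on a toy sequence of token feature vectors. This sketch assumes the CLS token comes first and ignores padding, which a real implementation would need to mask out. The helper functions are illustrative only.\n",
    "\n",
    "```python\n",
    "def cls_pooling(token_vectors):\n",
    "    # use the feature vector of the first (CLS) token\n",
    "    return token_vectors[0]\n",
    "\n",
    "def mean_pooling(token_vectors):\n",
    "    # average each feature dimension over all tokens\n",
    "    n = len(token_vectors)\n",
    "    return [sum(column) / n for column in zip(*token_vectors)]\n",
    "\n",
    "tokens = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]  # 3 tokens, 2 features each\n",
    "print(cls_pooling(tokens))\n",
    "print(mean_pooling(tokens))\n",
    "```"
   ]
  },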
  {
   "cell_type": "markdown",
   "id": "87a0ac4c",
   "metadata": {},
   "source": [
    "### model.hf_text.tokenizer_name\n",
    "\n",
    "Choose the text tokenizer. It is recommended to use the default auto tokenizer.\n",
    "\n",
    "- `hf_auto`: the [Huggingface auto tokenizer](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer).\n",
    "- `bert`: the [BERT tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/bert#transformers.BertTokenizer).\n",
    "- `electra`: the [ELECTRA tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/electra#transformers.ElectraTokenizer).\n",
    "- `clip`: the [CLIP tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/clip#transformers.CLIPTokenizer)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c50a4d84",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.tokenizer_name\": \"hf_auto\"})\n",
    "# using the tokenizer of the ELECTRA model\n",
    "predictor.fit(hyperparameters={\"model.hf_text.tokenizer_name\": \"electra\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8372c024",
   "metadata": {},
   "source": [
    "### model.hf_text.max_text_len\n",
    "\n",
    "Set the maximum text length. Different models may allow different maximum lengths. If `model.hf_text.max_text_len` > 0, we choose the minimum between `model.hf_text.max_text_len` and the maximum length allowed by the model. Setting `model.hf_text.max_text_len` <= 0 would use the model's maximum length."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1db84e23",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.max_text_len\": 512})\n",
    "# set to use the length allowed by the tokenizer.\n",
    "predictor.fit(hyperparameters={\"model.hf_text.max_text_len\": -1})\n",
    "```\n"
   ]
  },
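  {
   "cell_type": "markdown",
   "id": "9d2f4a1a",
   "metadata": {},
   "source": [
    "The rule for the effective text length can be written as a short function. Here `model_max_len` stands for the maximum length allowed by the chosen model, and the function name is hypothetical.\n",
    "\n",
    "```python\n",
    "def effective_max_text_len(max_text_len, model_max_len):\n",
    "    # a positive setting is capped by the model limit\n",
    "    if max_text_len > 0:\n",
    "        return min(max_text_len, model_max_len)\n",
    "    # zero or negative means: use the model limit\n",
    "    return model_max_len\n",
    "\n",
    "for setting in (128, 1024, -1):\n",
    "    print(setting, effective_max_text_len(setting, model_max_len=512))\n",
    "```"
   ]
  },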
  {
   "cell_type": "markdown",
   "id": "67a57f56",
   "metadata": {},
   "source": [
    "### model.hf_text.insert_sep\n",
    "\n",
    "Whether to insert the SEP token between texts from different columns of a dataframe."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c8f6b9",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.insert_sep\": True})\n",
    "# use no SEP token.\n",
    "predictor.fit(hyperparameters={\"model.hf_text.insert_sep\": False})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "525da692",
   "metadata": {},
   "source": [
    "### model.hf_text.text_segment_num\n",
    "\n",
    "How many text segments are used in a token sequence. Each text segment has one [token type ID](https://huggingface.co/transformers/v2.11.0/glossary.html#token-type-ids). We choose the minimum between `model.hf_text.text_segment_num` and the default used by the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "519d778a",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_segment_num\": 2})\n",
    "# use 1 text segment\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_segment_num\": 1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffe782fd",
   "metadata": {},
   "source": [
    "### model.hf_text.stochastic_chunk\n",
    "\n",
    "Whether to randomly cut a text chunk if a sample's text token number is larger than `model.hf_text.max_text_len`. If False, cut a token sequence from index 0 to the maximum allowed length. Otherwise, randomly sample a start index to cut a text chunk."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71f3f97e",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.stochastic_chunk\": False})\n",
    "# select a stochastic text chunk if a text sequence is over-long\n",
    "predictor.fit(hyperparameters={\"model.hf_text.stochastic_chunk\": True})\n",
    "```\n"
   ]
  },
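  {
   "cell_type": "markdown",
   "id": "9d2f4a1b",
   "metadata": {},
   "source": [
    "The two cutting behaviours can be sketched as follows. The function `cut_tokens` and its `rng` parameter are assumptions made for this illustration and do not mirror AutoMM's internal code.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def cut_tokens(tokens, max_len, stochastic_chunk=False, rng=random):\n",
    "    # short sequences are returned unchanged\n",
    "    if len(tokens) <= max_len:\n",
    "        return list(tokens)\n",
    "    # pick a random start index, or start from index 0\n",
    "    start = rng.randint(0, len(tokens) - max_len) if stochastic_chunk else 0\n",
    "    return tokens[start:start + max_len]\n",
    "\n",
    "tokens = list(range(10))\n",
    "print(cut_tokens(tokens, max_len=4))\n",
    "print(cut_tokens(tokens, max_len=4, stochastic_chunk=True))\n",
    "```"
   ]
  },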
  {
   "cell_type": "markdown",
   "id": "3d63c524",
   "metadata": {},
   "source": [
    "### model.hf_text.text_aug_detect_length\n",
    "\n",
    "Perform text augmentation only when the text token number is no less than `model.hf_text.text_aug_detect_length`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2538fbd3",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_aug_detect_length\": 10})\n",
    "# Allow text augmentation for texts whose token number is no less than 5\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_aug_detect_length\": 5})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4755a0c",
   "metadata": {},
   "source": [
    "### model.hf_text.text_trivial_aug_maxscale\n",
    "\n",
    "Set the maximum percentage of text tokens to conduct data augmentation. For each text token sequence, we randomly sample a percentage in [0, `model.hf_text.text_trivial_aug_maxscale`] and one operation from four trivial augmentations, including synonym replacement, random word swap, random word deletion, and random punctuation insertion, to do text augmentation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b44d07f",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, AutoMM doesn't do text augmentation\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_trivial_aug_maxscale\": 0})\n",
    "# Enable trivial augmentation by setting the max scale to 0.1\n",
    "predictor.fit(hyperparameters={\"model.hf_text.text_trivial_aug_maxscale\": 0.1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5019b9f4",
   "metadata": {},
   "source": [
    "### model.hf_text.gradient_checkpointing\n",
    "\n",
    "Whether to turn on gradient checkpointing to reduce the memory consumption for calculating gradients. For more about gradient checkpointing, feel free to refer to [relevant tutorials](https://github.com/cybertronai/gradient-checkpointing)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38476035",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, AutoMM doesn't turn on gradient checkpointing\n",
    "predictor.fit(hyperparameters={\"model.hf_text.gradient_checkpointing\": False})\n",
    "# Turn on gradient checkpointing\n",
    "predictor.fit(hyperparameters={\"model.hf_text.gradient_checkpointing\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec6d7d31",
   "metadata": {},
   "source": [
    "### model.timm_image.checkpoint_name\n",
    "\n",
    "Select an image backbone from [TIMM](https://github.com/rwightman/pytorch-image-models/tree/master/timm/models)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20aa69bb",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.timm_image.checkpoint_name\": \"swin_base_patch4_window7_224\"})\n",
    "# choose a vit base\n",
    "predictor.fit(hyperparameters={\"model.timm_image.checkpoint_name\": \"vit_base_patch32_224\"})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### model.timm_image.train_transforms\n",
    "\n",
    "Augment images in training. Support passing a list of supported strings chosen from (`resize_to_square`, `resize_shorter_side`, `center_crop`, `random_resize_crop`, `random_horizontal_flip`, `random_vertical_flip`, `color_jitter`, `affine`, `randaug`, `trivial_augment`), or a list of callable and pickle-able transform objects. For example, you use the torchvision transforms (https://pytorch.org/vision/stable/transforms.html)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [\"resize_shorter_side\", \"center_crop\", \"trivial_augment\"]})\n",
    "# use random resize crop and random horizontal flip\n",
    "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [\"random_resize_crop\", \"random_horizontal_flip\"]})\n",
    "# or use a list of callable and pickle-able objects, e.g., torchvision transforms\n",
    "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [torchvision.transforms.RandomResizedCrop(224), torchvision.transforms.RandomHorizontalFlip()]})\n",
    "```"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### model.timm_image.val_transforms\n",
    "\n",
    "Transform images in validation/test/deployment. Similar to `model.timm_image.train_transforms`, support a list of strings or callable and pickle-able objects to transform images."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [\"resize_shorter_side\", \"center_crop\"]})\n",
    "# resize image to square\n",
    "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [\"resize_to_square\"]})\n",
    "# or use a list of callable and pickle-able objects, e.g., torchvision transforms\n",
    "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [torchvision.transforms.Resize((224, 224)]})\n",
    "```\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "id": "690a7766",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "### data.image.missing_value_strategy\n",
    "\n",
    "How to deal with missing images, opening which fails.\n",
    "\n",
    "- `\"skip\"`: skip a sample with missing images.\n",
    "- `\"zero\"`: use zero image to replace a missing image."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed5ad640",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.image.missing_value_strategy\": \"zero\"})\n",
    "# skip the image\n",
    "predictor.fit(hyperparameters={\"data.image.missing_value_strategy\": \"skip\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b5a689b",
   "metadata": {},
   "source": [
    "### data.text.normalize_text\n",
    "Whether to normalize text with encoding problems. If True, TextProcessor will run through a series of encoding and decoding for text normalization. Please refer to the [Example](https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize) of Kaggle competition for applying text normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab6c46ad",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.text.normalize_text\": False})\n",
    "# turn on text normalization\n",
    "predictor.fit(hyperparameters={\"data.text.normalize_text\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "177eb155",
   "metadata": {},
   "source": [
    "### data.categorical.convert_to_text\n",
    "\n",
    "Whether to treat categorical data as text. If True, no categorical models, e.g., `\"categorical_mlp\"` and `\"categorical_transformer\"`, would be used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "267ad0a9",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.categorical.convert_to_text\": True})\n",
    "# turn off the conversion\n",
    "predictor.fit(hyperparameters={\"data.categorical.convert_to_text\": False})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bf8d9e6",
   "metadata": {},
   "source": [
    "### data.numerical.convert_to_text\n",
    "\n",
    "Whether to convert numerical data to text. If True, no numerical models e.g., `\"numerical_mlp\"` and `\"numerical_transformer\"`, would be used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f158d9a0",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.numerical.convert_to_text\": False})\n",
    "# turn on the conversion\n",
    "predictor.fit(hyperparameters={\"data.numerical.convert_to_text\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "daaa41c9",
   "metadata": {},
   "source": [
    "### data.numerical.scaler_with_mean\n",
    "\n",
    "If True, center the numerical data (not including the numerical labels) before scaling."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "984abb92",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_mean\": True})\n",
    "# turn off centering\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_mean\": False})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "589677a4",
   "metadata": {},
   "source": [
    "### data.numerical.scaler_with_std\n",
    "\n",
    "If True, scale the numerical data (not including the numerical labels) to unit variance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cfca7db",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_std\": True})\n",
    "# turn off scaling\n",
    "predictor.fit(hyperparameters={\"data.numerical.scaler_with_std\": False})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9241360b",
   "metadata": {},
   "source": [
    "### data.label.numerical_label_preprocessing\n",
    "\n",
    "How to process the numerical labels in regression tasks.\n",
    "\n",
    "- `\"standardscaler\"`: standardizes numerical labels by removing the mean and scaling to unit variance.\n",
    "- `\"minmaxscaler\"`: transforms numerical labels by scaling each feature to range (0, 1)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bea7d018",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.label.numerical_label_preprocessing\": \"standardscaler\"})\n",
    "# scale numerical labels to (0, 1)\n",
    "predictor.fit(hyperparameters={\"data.label.numerical_label_preprocessing\": \"minmaxscaler\"})\n",
    "```\n"
   ]
  },
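  {
   "cell_type": "markdown",
   "id": "label-scaling-sketch",
   "metadata": {},
   "source": [
    "As a rough sketch of what the two options compute, assuming they follow the usual definitions of standard scaling and min-max scaling (the symbols below are illustrative and not taken from the AutoMM source), a numerical label $y$ is transformed as\n",
    "\n",
    "$$y_{\\text{standard}} = \\frac{y - \\mu}{\\sigma}, \\qquad y_{\\text{minmax}} = \\frac{y - y_{\\min}}{y_{\\max} - y_{\\min}},$$\n",
    "\n",
    "where $\\mu$, $\\sigma$, $y_{\\min}$, and $y_{\\max}$ are the mean, standard deviation, minimum, and maximum of the training labels. For example, with training labels 10, 20, and 30, min-max scaling yields 0, 0.5, and 1."
   ]
  },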
  {
   "cell_type": "markdown",
   "id": "103f689a",
   "metadata": {},
   "source": [
    "### data.pos_label\n",
    "\n",
    "The positive label in a binary classification task. Users need to specify this label to properly use some metrics, e.g., roc_auc, average_precision, and f1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a14b1af",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.pos_label\": None})\n",
    "# assume the labels are [\"changed\", \"not changed\"] and \"changed\" is the positive label\n",
    "predictor.fit(hyperparameters={\"data.pos_label\": \"changed\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3e3242f",
   "metadata": {},
   "source": [
    "### data.mixup.turn_on\n",
    "\n",
    "If True, use Mixup in training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0161ba46",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_on\": False})\n",
    "# turn on Mixup\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_on\": True})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd97b924",
   "metadata": {},
   "source": [
    "### data.mixup.mixup_alpha\n",
    "\n",
    "Mixup alpha value. Mixup is active if `data.mixup.mixup_alpha` > 0."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd9bc14b",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.mixup_alpha\": 0.8})\n",
    "# set it to 1.0 to turn off Mixup\n",
    "predictor.fit(hyperparameters={\"data.mixup.mixup_alpha\": 1.0})\n",
    "```\n"
   ]
  },
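  {
   "cell_type": "markdown",
   "id": "mixup-alpha-sketch",
   "metadata": {},
   "source": [
    "For intuition, here is the standard Mixup formulation, given as a sketch under the assumption that AutoMM follows it (the symbols are illustrative). Two training samples $(x_i, y_i)$ and $(x_j, y_j)$ are blended with a weight drawn from a Beta distribution whose parameter $\\alpha$ plays the role of `data.mixup.mixup_alpha`:\n",
    "\n",
    "$$\\lambda \\sim \\mathrm{Beta}(\\alpha, \\alpha), \\qquad \\tilde{x} = \\lambda x_i + (1 - \\lambda) x_j, \\qquad \\tilde{y} = \\lambda y_i + (1 - \\lambda) y_j.$$\n",
    "\n",
    "Smaller values of $\\alpha$ concentrate $\\lambda$ near 0 or 1, so the mixed sample stays close to one of the two originals, while $\\alpha = 1$ draws $\\lambda$ uniformly from $[0, 1]$."
   ]
  },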
  {
   "cell_type": "markdown",
   "id": "b48c9408",
   "metadata": {},
   "source": [
    "### data.mixup.cutmix_alpha\n",
    "\n",
    "Cutmix alpha value. Cutmix is active if `data.mixup.cutmix_alpha` > 0."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fc9b53a",
   "metadata": {},
   "source": [
    "```\n",
    "# by default, Cutmix is turned off by using alpha 1.0\n",
    "predictor.fit(hyperparameters={\"data.mixup.cutmix_alpha\": 1.0})\n",
    "# turn it on by choosing a number in range (0, 1)\n",
    "predictor.fit(hyperparameters={\"data.mixup.cutmix_alpha\": 0.8})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3a58751",
   "metadata": {},
   "source": [
    "### data.mixup.prob\n",
    "\n",
    "The probability of conducting Mixup or Cutmix if enabled."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "738cdcc7",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.prob\": 1.0})\n",
    "# set probability to 0.5\n",
    "predictor.fit(hyperparameters={\"data.mixup.prob\": 0.5})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5991c094",
   "metadata": {},
   "source": [
    "### data.mixup.switch_prob\n",
    "\n",
    "The probability of switching to Cutmix instead of Mixup when both are active."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d24393ef",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.switch_prob\": 0.5})\n",
    "# set probability to 0.7\n",
    "predictor.fit(hyperparameters={\"data.mixup.switch_prob\": 0.7})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab459677",
   "metadata": {},
   "source": [
    "### data.mixup.mode\n",
    "\n",
    "How to apply Mixup or Cutmix params (per `\"batch\"`, `\"pair\"` (pair of elements), `\"elem\"` (element)).\n",
    "See [here](https://github.com/rwightman/pytorch-image-models/blob/d30685c283137b4b91ea43c4e595c964cd2cb6f0/timm/data/mixup.py#L211-L216) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ada57733",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.mode\": \"batch\"})\n",
    "# use \"pair\"\n",
    "predictor.fit(hyperparameters={\"data.mixup.mode\": \"pair\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737d454d",
   "metadata": {},
   "source": [
    "### data.mixup.label_smoothing\n",
    "\n",
    "Apply label smoothing to the mixed label tensors."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40f1d216",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.label_smoothing\": 0.1})\n",
    "# set it to 0.2\n",
    "predictor.fit(hyperparameters={\"data.mixup.label_smoothing\": 0.2})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73c8a29d",
   "metadata": {},
   "source": [
    "### data.mixup.turn_off_epoch\n",
    "\n",
    "Stop Mixup or Cutmix after reaching this number of epochs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0c3715f",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_off_epoch\": 5})\n",
    "# turn off mixup after 7 epochs\n",
    "predictor.fit(hyperparameters={\"data.mixup.turn_off_epoch\": 7})\n",
    "```\n"
   ]
  },
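  {
   "cell_type": "markdown",
   "id": "mixup-combined-sketch",
   "metadata": {},
   "source": [
    "The Mixup options above can be passed together in a single `hyperparameters` dictionary. The following is a minimal sketch that assumes `predictor` has already been constructed, as in the snippets above; the values are illustrative, not recommendations.\n",
    "\n",
    "```\n",
    "# other fit arguments (e.g., the training data) are omitted, as in the snippets above\n",
    "predictor.fit(\n",
    "    hyperparameters={\n",
    "        \"data.mixup.turn_on\": True,       # enable Mixup/Cutmix\n",
    "        \"data.mixup.mixup_alpha\": 0.8,    # Mixup alpha value\n",
    "        \"data.mixup.cutmix_alpha\": 1.0,   # Cutmix alpha value\n",
    "        \"data.mixup.prob\": 0.5,           # probability of applying Mixup or Cutmix\n",
    "        \"data.mixup.turn_off_epoch\": 7,   # stop Mixup/Cutmix after 7 epochs\n",
    "    }\n",
    ")\n",
    "```\n"
   ]
  },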
  {
   "cell_type": "markdown",
   "id": "4fb33f07",
   "metadata": {},
   "source": [
    "## Distiller\n",
    "\n",
    "### distiller.soft_label_loss_type\n",
    "\n",
    "What loss to compute when using teacher's output (logits) to supervise student's."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3ca2c3d",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_loss_type\": \"cross_entropy\"})\n",
    "# default used by AutoMM for regression\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_loss_type\": \"mse\"})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c91287a0",
   "metadata": {},
   "source": [
    "### distiller.temperature\n",
    "\n",
    "Before computing the soft label loss, scale the teacher and student logits with it (teacher_logits / temperature, student_logits / temperature)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f67e3e1",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.temperature\": 5})\n",
    "# set temperature to 1\n",
    "predictor.fit(hyperparameters={\"distiller.temperature\": 1})\n",
    "```\n"
   ]
  },
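  {
   "cell_type": "markdown",
   "id": "distiller-temp-sketch",
   "metadata": {},
   "source": [
    "To illustrate the effect of the temperature, here is a sketch that assumes the soft label loss is computed on softmax probabilities, as in standard knowledge distillation (the symbols are illustrative and not taken from the AutoMM source). Logits $z$ scaled by a temperature $T$ give the class probabilities\n",
    "\n",
    "$$p_i = \\frac{\\exp(z_i / T)}{\\sum_j \\exp(z_j / T)}.$$\n",
    "\n",
    "With $T = 1$ this is the ordinary softmax. Larger values of $T$ flatten the distribution, so the student also learns from the relative scores the teacher assigns to the non-top classes."
   ]
  },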
  {
   "cell_type": "markdown",
   "id": "2f95727c",
   "metadata": {},
   "source": [
    "### distiller.hard_label_weight\n",
    "\n",
    "Scale the student's hard label (groundtruth) loss with this weight (hard_label_loss \\* hard_label_weight)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f5d5eca",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.hard_label_weight\": 0.2})\n",
    "# set not to scale the hard label loss\n",
    "predictor.fit(hyperparameters={\"distiller.hard_label_weight\": 1})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ebc0b75",
   "metadata": {},
   "source": [
    "### distiller.soft_label_weight\n",
    "\n",
    "Scale the student's soft label (teacher's output) loss with this weight (soft_label_loss \\* soft_label_weight)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b3b90c2",
   "metadata": {},
   "source": [
    "```\n",
    "# default used by AutoMM for classification\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_weight\": 50})\n",
    "# set not to scale the soft label loss\n",
    "predictor.fit(hyperparameters={\"distiller.soft_label_weight\": 1})\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
